import os
import pickle

MAX_LEN = 7
from mydataset import *
class Lang:
    """Vocabulary mapping between words and integer ids.

    Ids 0-3 are reserved for the special tokens SOS, EOS, PAD and UNK.
    """

    def __init__(self, name):
        self.name = name
        # word -> id, pre-seeded with the four special tokens
        self.word2index = {"SOS": 0, "EOS": 1, "PAD": 2, "UNK": 3}
        # word -> occurrence count (special tokens are not counted)
        self.word2count = {}
        # id -> word, inverse of word2index
        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD", 3: "UNK"}
        self.n_words = 4  # next free id; starts after the special tokens

    def addSentence(self, sentence):
        """Register every space-separated token of *sentence*."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Add *word* to the vocabulary, or bump its occurrence count."""
        if word in self.word2index:
            self.word2count[word] += 1
        else:
            idx = self.n_words
            self.word2index[word] = idx
            self.index2word[idx] = word
            self.word2count[word] = 1
            self.n_words = idx + 1

def creat_lang(max_num):
    """Build input/output vocabularies from "data/seq.data" and pickle them.

    Each line of the data file is expected to be "question<TAB>answer",
    both sides space-tokenized; pairs where either side has 4 or fewer
    tokens are skipped.

    NOTE(review): *max_num* is accepted but never used — presumably it was
    meant to cap the number of lines read; confirm intent before wiring
    it up (kept unused here to preserve current behavior).

    Returns:
        (input_lang, out_lang): the two populated Lang vocabularies.
    """
    input_lang = Lang("问题字典")
    out_lang = Lang("答案字典")

    with open("data/seq.data", "r", encoding="utf-8") as f:
        total_data = f.readlines()

    for line in total_data:
        data_c = line.strip().split("\t")
        # keep only well-formed pairs where both sides have > 4 tokens
        if len(data_c) == 2 and len(data_c[0].split(" ")) > 4 and len(data_c[1].split(" ")) > 4:
            input_lang.addSentence(data_c[0])
            out_lang.addSentence(data_c[1])

    # Fix: ensure the output directory exists before pickling; otherwise
    # open() raises FileNotFoundError on a fresh checkout.
    os.makedirs("dict", exist_ok=True)
    with open("dict/input_lang.pkl", 'wb') as f:
        pickle.dump(input_lang, f)
    with open("dict/out_lang.pkl", 'wb') as f:
        pickle.dump(out_lang, f)
    return input_lang, out_lang

# Build train_data: convert raw sentences into fixed-length id sequences
def tensorfromsentenct(sentenct, lang, tag="input", max_len=None):
    """Convert a space-tokenized sentence into a fixed-length list of ids.

    Args:
        sentenct: space-separated sentence (name kept for backward compat).
        lang: Lang vocabulary; words missing from it map to UNK id 3.
        tag: which sequence to build —
            "input":      encoder input, padded/truncated to max_len;
            "tag_input":  decoder input, SOS (0) prepended, then padded/truncated;
            "tag_output": decoder target, EOS (1) always appended, then padded.
        max_len: target sequence length; generalized from the hard-coded
            module constant — defaults to MAX_LEN when not given, so
            existing callers are unaffected.

    Returns:
        list[int] of token ids (length max_len for each of the three tags).
    """
    if max_len is None:
        max_len = MAX_LEN
    # look up each word's id; unknown words fall back to UNK id 3
    ids = [lang.word2index.get(word, 3) for word in sentenct.split(" ")]
    if tag == "input":
        if len(ids) < max_len:
            ids = ids + [2] * (max_len - len(ids))  # pad with PAD (2)
        else:
            ids = ids[:max_len]
    if tag == "tag_input":
        ids = [0] + ids  # prepend SOS
        if len(ids) < max_len:
            ids = ids + [2] * (max_len - len(ids))  # pad with PAD (2)
        else:
            ids = ids[:max_len]
    if tag == "tag_output":
        if len(ids) < max_len - 1:
            ids = ids + [1]  # append EOS
            ids = ids + [2] * (max_len - len(ids))  # pad with PAD (2)
        else:
            # truncate so EOS always fits as the final token
            ids = ids[:max_len - 1]
            ids = ids + [1]
    return ids



def read_data(input_lang, out_lang, data_path="data/seq.data", num_max=500):
    """Load question/answer pairs and convert them to padded id sequences.

    Returns three parallel lists: encoder inputs ("input"), decoder
    inputs ("tag_input") and decoder targets ("tag_output"), each built
    with tensorfromsentenct.

    NOTE(review): *num_max* is currently unused — the original cap
    (`f.readlines()[:num_max]`) was commented out, so the whole file is
    always processed; confirm which behavior is intended.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    input_data, tag_input, tag_output = [], [], []
    for raw in lines:
        fields = raw.strip().split("\t")
        # skip malformed lines
        if len(fields) != 2:
            continue
        question, answer = fields
        # skip pairs where either side has 4 or fewer tokens
        if len(question.split(" ")) <= 4 or len(answer.split(" ")) <= 4:
            continue
        input_data.append(tensorfromsentenct(question, input_lang, "input"))
        tag_input.append(tensorfromsentenct(answer, out_lang, "tag_input"))
        tag_output.append(tensorfromsentenct(answer, out_lang, "tag_output"))
    return input_data, tag_input, tag_output


















